Initial version porting the eden configs over to the new evo2 recipe #1502
Conversation
Signed-off-by: John St. John <jstjohn@nvidia.com>
**Important: Review skipped** — auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configuration.

Configuration used: Path: `.coderabbit.yaml` · Review profile: CHILL · Plan: Pro
**📝 Walkthrough**

This PR introduces support for Eden (Llama 3.1) model variants alongside the existing Hyena SSM models in the Evo2 framework. New checkpoint conversion utilities enable Savanna-to-MBridge and MBridge-to-Vortex transformations, updated runtime scripts support model-type branching, and CLI entry points are added for checkpoint operations. Comprehensive test coverage validates model providers, roundtrip conversions, and inference/training workflows for both architectures.
**Sequence Diagram(s)**

```mermaid
sequenceDiagram
    actor User
    participant Savanna as Savanna<br/>Checkpoint
    participant Conv1 as savanna_to_mbridge<br/>Converter
    participant MBridge as Megatron Bridge<br/>DCP Checkpoint
    participant Conv2 as mbridge_to_vortex<br/>Exporter
    participant Vortex as Vortex<br/>Format
    User->>Conv1: savanna_to_mbridge()<br/>(savanna_path, model_size)
    Conv1->>Savanna: load_savanna_state_dict()
    Savanna-->>Conv1: state_dict
    Conv1->>Conv1: select model_provider<br/>from MODEL_OPTIONS
    Conv1->>Conv1: savanna_to_mbridge_state_dict()<br/>(apply pattern mapping)
    Conv1->>Conv1: package_mbridge_checkpoint()<br/>(write DCP structure)
    Conv1-->>MBridge: checkpoint written
    MBridge-->>User: output_path
    User->>Conv2: mbridge_to_vortex()<br/>(mbridge_dir, model_size)
    Conv2->>MBridge: load_mbridge_state_dict()
    MBridge-->>Conv2: state_dict
    Conv2->>Conv2: select HyenaModelProvider<br/>from HYENA_MODEL_OPTIONS
    Conv2->>Conv2: mbridge_to_vortex_state_dict()<br/>(per-layer conversion:<br/>embedding, decoder blocks,<br/>final norm)
    Conv2->>Conv2: _convert_hyena_layer()<br/>or _convert_attention_layer()
    Conv2->>Conv2: _convert_mlp()
    Conv2->>Conv2: _build_vortex_config()
    Conv2-->>Vortex: .pt + config.json
    Vortex-->>User: export complete
```
```mermaid
sequenceDiagram
    actor User
    participant Train as train.py<br/>Pretraining
    participant Infer as infer_model_type()<br/>Classifier
    participant Provider as Model Provider<br/>(Hyena or Eden)
    participant FwdStep as Forward Step<br/>Function
    User->>Train: launch with --model-size
    Train->>Infer: infer_model_type(model_size)
    alt model_size in HYENA_MODEL_OPTIONS
        Infer-->>Train: "hyena"
        Train->>Provider: HyenaModelProvider
        Train->>FwdStep: select hyena_forward_step
    else model_size in EDEN_MODEL_OPTIONS
        Infer-->>Train: "eden"
        Train->>Provider: EdenModelProvider
        Train->>FwdStep: select gpt_forward_step
    end
    alt model_type != "hyena" && fp32_residual_connection
        Train->>Train: disable fp32_residual_connection
    end
    Train->>FwdStep: launch pretraining<br/>with forward_step_fn
    FwdStep-->>Train: loss, gradients
    Train-->>User: training complete
```
**Estimated code review effort:** 🎯 4 (Complex) | ⏱️ ~60 minutes
**🚥 Pre-merge checks:** ✅ Passed checks (3 passed)
Signed-off-by: John St. John <jstjohn@nvidia.com>
…tions Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
@coderabbitai review
✅ Actions performed: Review triggered.
Actionable comments posted: 9
🧹 Nitpick comments (4)
bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py (1)
**714-805**: Use the source `infer.py` launch path for the Eden subprocesses.

This file already introduced `_infer_script_path()` and the `PYTHONPATH` prepend so local `infer.py` fixes are exercised without reinstalling. These new Eden cases still launch `-m bionemo.evo2.run.infer`, so a non-editable test environment can end up validating the installed package instead of this PR.

Suggested direction:

```diff
-    "-m",
-    "bionemo.evo2.run.infer",
+    str(_infer_script_path()),
     "--ckpt-dir",
     str(mbridge_eden_checkpoint_path),
     ...
 ]
 env = copy.deepcopy(PRETEST_ENV)
+src_dir = str(_recipe_root() / "src")
+env["PYTHONPATH"] = src_dir + os.pathsep + env.get("PYTHONPATH", "")
```

Apply the same launch pattern in `test_infer_eden_deterministic()`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py` around lines 714-805: the Eden tests are launching the installed module (`-m bionemo.evo2.run.infer`) instead of the local `infer.py`. Update `test_infer_eden_deterministic` to use the same launch pattern as `test_infer_eden_runs`: replace the `-m bionemo.evo2.run.infer` style invocation with the script path from `_infer_script_path()` (use that function to build the cmd entry) and ensure the env prepends `PYTHONPATH` the same way (reuse the `PRETEST_ENV` modification logic used in `test_infer_eden_runs`) so the local source `infer.py` is executed. Reference `test_infer_eden_deterministic`, `test_infer_eden_runs`, `_infer_script_path`, and `PRETEST_ENV` to locate where to apply the changes.

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_eden_llama_roundtrip.py (1)
**63**: Consider adding `weights_only=True` or explicit `weights_only=False` to `torch.load`.

`torch.load` without `weights_only` will default to `False` and emit a deprecation warning in recent PyTorch versions. Since these are locally-generated prediction files (tensors only), `weights_only=True` should work and is safer.

Suggested fix:

```diff
-    preds = [torch.load(pf) for pf in pred_files]
+    preds = [torch.load(pf, weights_only=True) for pf in pred_files]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_eden_llama_roundtrip.py` at line 63: the list comprehension using `torch.load(pf)` to build `preds` should pass an explicit `weights_only` argument to avoid the deprecation warning and ensure correct behavior for tensor-only prediction files; update the comprehension that references `pred_files` and `preds` to call `torch.load(pf, weights_only=True)` (or `weights_only=False` if non-weight objects are expected) so loading is explicit and future-proof.

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py (2)
**132**: `weights_only=False` is necessary but has security implications.

This is required for loading Savanna checkpoints that may contain non-tensor data, but it can execute arbitrary code via pickle when loading untrusted files. Consider adding a comment noting this:

Suggested documentation:

```diff
-    raw = torch.load(str(path), map_location="cpu", weights_only=False)
+    # Note: weights_only=False is required for Savanna checkpoints containing custom objects.
+    # Only load checkpoints from trusted sources.
+    raw = torch.load(str(path), map_location="cpu", weights_only=False)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py` at line 132: add a brief inline comment next to the `torch.load` call (the line setting `raw = torch.load(str(path), map_location="cpu", weights_only=False)`) explaining that `weights_only=False` is required to load Savanna checkpoints containing non-tensor data but is unsafe for untrusted files because it uses pickle and can execute arbitrary code; state that callers must ensure the path is trusted (or sanitize/validate inputs) before loading.

**95-96**: Bare `except Exception` is too broad.

This catches all exceptions including unexpected ones like `KeyboardInterrupt` (actually `BaseException`, but `Exception` still catches many things). Consider catching specific HuggingFace Hub exceptions:

Suggested fix:

```diff
-    except Exception:
+    except (huggingface_hub.errors.EntryNotFoundError, huggingface_hub.errors.RepositoryNotFoundError):
         logger.warning(f"Single-file download failed for {repo_id}, trying multi-part shards...")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py` around lines 95-96: the bare `except` in the single-file download block (the `except Exception:` that logs "Single-file download failed for {repo_id}, trying multi-part shards...") is too broad; replace it by catching specific HF and network-related exceptions (for example `RepositoryNotFoundError`, `RevisionNotFoundError`, and `requests.exceptions.HTTPError`/`ConnectionError`) and bind the exception (`except (RepositoryNotFoundError, RevisionNotFoundError, HTTPError, ConnectionError) as e:`) so you can log the actual error (include `exc_info` or `str(e)`) while allowing other unexpected exceptions to propagate (re-raise or don't catch them). Locate the except block around the single-file download attempt and update the except clause and logger call accordingly.
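The narrowed-exception pattern can be sketched standalone. The error class and helper functions below are stand-ins for illustration, not the real `huggingface_hub` types or the recipe's actual download code:

```python
import logging

logger = logging.getLogger("checkpoint_download")


class EntryNotFoundError(Exception):
    """Stand-in for huggingface_hub's single-file-missing error (hypothetical here)."""


def download_single_file(repo_id: str) -> str:
    # Simulate a repo that only publishes multi-part shards.
    raise EntryNotFoundError(f"no single-file checkpoint in {repo_id}")


def download_shards(repo_id: str) -> list[str]:
    return [f"{repo_id}/shard_00.pt", f"{repo_id}/shard_01.pt"]


def fetch_checkpoint(repo_id: str) -> list[str]:
    try:
        return [download_single_file(repo_id)]
    except EntryNotFoundError as e:
        # Narrow clause: unrelated errors (bugs, KeyboardInterrupt) still propagate.
        logger.warning("Single-file download failed for %s (%s), trying multi-part shards...", repo_id, e)
        return download_shards(repo_id)


print(fetch_checkpoint("arcinstitute/savanna_evo2_1b_base"))
```

Because only the named error is caught, a typo or unexpected failure inside the download path surfaces immediately instead of being silently converted into a shard retry.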
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@bionemo-recipes/recipes/evo2_megatron/README.md`:
- Around line 157-162: Add a language identifier (bash) to the fenced code
blocks containing the shell commands (e.g., the blocks that show
evo2_export_mbridge_to_vortex and evo2_convert_savanna_to_mbridge) so
markdownlint stops warning; locate the backtick fences around those command
examples and change the opening fence from ``` to ```bash for each occurrence
(including the second block around the evo2_convert_savanna_to_mbridge example).
In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/eden_provider.py`:
- Around line 130-139: The patch_eden_tokenizer function is defined but never
used at runtime; either remove this dead function and its export (and update the
unit test to use the runtime implementation) or integrate it into the recipes
tokenizer flow by importing and calling patch_eden_tokenizer immediately after
the tokenizer is constructed in predict.py (so the tokenizer uses BOS=1, EOS=2,
SEP=3, PAD=0); also ensure any exported symbol lists are updated to remove the
orphaned function if you delete it and avoid duplicating functionality already
present in the other package-level patch implementation.
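As a standalone illustration of the special-token patch described above (the tokenizer class here is a minimal stand-in with hypothetical attribute names, not the recipe's real byte-level tokenizer):

```python
class ByteLevelTokenizer:
    """Minimal stand-in tokenizer (hypothetical attribute names)."""

    def __init__(self):
        self.bos_id = None
        self.eos_id = None
        self.sep_id = None
        self.pad_id = None


def patch_eden_tokenizer(tokenizer):
    """Assign the Eden special-token IDs named in the review: BOS=1, EOS=2, SEP=3, PAD=0."""
    tokenizer.bos_id = 1
    tokenizer.eos_id = 2
    tokenizer.sep_id = 3
    tokenizer.pad_id = 0
    return tokenizer


# Per the review, the patch should run immediately after the tokenizer is constructed.
tok = patch_eden_tokenizer(ByteLevelTokenizer())
assert (tok.bos_id, tok.eos_id, tok.sep_id, tok.pad_id) == (1, 2, 3, 0)
```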
In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py`:
- Around line 1115-1132: The bug: MODEL_OPTIONS is built as
{**HYENA_MODEL_OPTIONS, **EDEN_MODEL_OPTIONS} which lets EDEN override HYENA on
key collision, but infer_model_type checks HYENA first causing inconsistent
behavior; fix by adding a runtime collision check after constructing
MODEL_OPTIONS that computes collisions = set(HYENA_MODEL_OPTIONS) &
set(EDEN_MODEL_OPTIONS) and either raise a clear ValueError (or log and resolve
to a chosen precedence) if collisions is non-empty, and update infer_model_type
to rely on MODEL_OPTIONS (or document the chosen precedence) so behavior is
consistent; also update infer_model_type's docstring to include an Args section
describing the model_size parameter.
- Around line 656-699: Hyena20bModelProvider defines an incorrect/unused
attribute short_conv_len and an orphan hyena_out_proj_bias; remove
short_conv_len (it duplicates/typoed counterpart hyena_short_conv_len) and
delete hyena_out_proj_bias unless you add a corresponding field in HyenaConfig
and wire it into the model code; update Hyena20bModelProvider by removing the
short_conv_len and hyena_out_proj_bias declarations (or rename short_conv_len to
hyena_short_conv_len only if HyenaConfig lacks that field and you also add it
there).
In `@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py`:
- Around line 334-336: You're directly mutating the private attribute
model_provider._pg_collection with
ProcessGroupCollection.use_mpu_process_groups(), which is fragile; instead
expose and use a public API or setter on the provider (e.g., add or call a
method like set_process_group_collection or a constructor/init parameter on the
ModelProvider class) so non-Hyena models can be configured without touching
internals—update the provider implementation to accept and store the
ProcessGroupCollection via that public method and replace the direct assignment
at the call site with the new setter or init call, keeping Hyena models'
internal behavior unchanged.
In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/mbridge_to_vortex.py`:
- Around line 133-167: The export currently silently omits required tensors
(e.g., embedding.word_embeddings.weight, decoder.final_norm.weight and per-layer
weights produced by _convert_hyena_layer, _convert_attention_layer, and
_convert_mlp) so add a validation pass after the loop that defines the mandatory
target keys (embedding_layer.weight, unembed.weight, norm.scale and all expected
per-layer keys derived from prefix/block_prefix and the pattern) and check their
presence in the resulting vortex_sd (or mbridge_state_dict if conversions expect
source keys); collect missing keys into a list and raise a descriptive exception
listing layer and key names (including references to the layer index and symbol)
before writing the .pt/config.json to fail fast on bad --model-size or --no-te
choices.
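A minimal sketch of the fail-fast validation pass described above. The top-level key names are the ones quoted in the comment (`embedding_layer.weight`, `unembed.weight`, `norm.scale`); the per-layer key pattern is a hypothetical placeholder:

```python
def validate_vortex_state_dict(vortex_sd: dict, num_layers: int) -> None:
    """Raise a descriptive error if mandatory export keys are missing."""
    required = {"embedding_layer.weight", "unembed.weight", "norm.scale"}
    for i in range(num_layers):
        # Placeholder per-layer key; the real pattern depends on the block prefix.
        required.add(f"blocks.{i}.pre_norm.scale")
    missing = sorted(k for k in required if k not in vortex_sd)
    if missing:
        raise ValueError(f"Vortex export missing tensors: {missing}")


# A state dict lacking the final norm fails fast instead of writing a bad .pt:
sd = {"embedding_layer.weight": 0, "unembed.weight": 0, "blocks.0.pre_norm.scale": 0}
try:
    validate_vortex_state_dict(sd, num_layers=1)
except ValueError as e:
    print(e)  # names norm.scale as the missing key
```

Running a check like this before the `.pt`/`config.json` write turns a bad `--model-size` or `--no-te` choice into an immediate, descriptive error.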
- Around line 48-56: The current logic assumes mbridge_ckpt_dir is the
checkpoint root and fails if the user passes an iter_* directory; modify the
resolver around latest_file/iter_dir so that if mbridge_ckpt_dir.name matches
the iter_* pattern (e.g., startswith "iter_" or matches r"^iter_\d+$") you treat
mbridge_ckpt_dir itself as the iter_dir; otherwise keep the existing flow (check
for latest_checkpointed_iteration.txt, parse iteration into iter_{:07d}, or
fallback to glob("iter_*")). Update uses of iteration/iter_dirs to reflect this
early-path selection so valid direct iter_* paths are accepted.
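The resolver change can be sketched with `pathlib`; the directory layout and tracker-file name come from the comment, and this is an illustration under those assumptions rather than the converter's actual code:

```python
import re
from pathlib import Path


def resolve_iter_dir(mbridge_ckpt_dir: Path) -> Path:
    """Accept either a checkpoint root or a direct iter_* directory."""
    if re.fullmatch(r"iter_\d+", mbridge_ckpt_dir.name):
        return mbridge_ckpt_dir  # user passed the iteration directory itself
    tracker = mbridge_ckpt_dir / "latest_checkpointed_iteration.txt"
    if tracker.exists():
        iteration = int(tracker.read_text().strip())
        return mbridge_ckpt_dir / f"iter_{iteration:07d}"
    iter_dirs = sorted(mbridge_ckpt_dir.glob("iter_*"))
    if not iter_dirs:
        raise FileNotFoundError(f"No iter_* checkpoints under {mbridge_ckpt_dir}")
    return iter_dirs[-1]  # fall back to the newest iter_* subdirectory


# Direct iter_* paths are accepted as-is, without touching the filesystem:
assert resolve_iter_dir(Path("/ckpts/run1/iter_0000001")).name == "iter_0000001"
```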
In
`@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_checkpoint_roundtrip.py`:
- Around line 39-79: The fixtures download mutable HuggingFace checkpoints and
allow unsafe pickle loading; update savanna_checkpoint_path and
vortex_reference_path to pass explicit immutable commit SHAs via the revision=
parameter in their hf_hub_download(...) calls (use the specific commit SHA for
SAVANNA_1B_REPO and VORTEX_1B_REPO to pin the golden data), and modify
vortex_reference_sd to load the reference safely by calling torch.load(...,
map_location="cpu", weights_only=True) (or prefer safetensors if a .safetensors
artifact exists) so remote .pt files cannot execute pickle code.
In
`@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_eden_llama_roundtrip.py`:
- Around line 148-186: The test currently computes original_preds via
_run_predict(eden_ckpt, ...) but then only compares original_hf and
reimported_hf logits; change the test to run _run_predict on the roundtripped HF
checkpoint (use hf_reimported_dir) to produce hf_preds and then compare
original_preds["log_probs_seqs"] to hf_preds["log_probs_seqs"] (or the
equivalent per-token log-prob key) using a numeric assert (e.g.,
torch.testing.assert_close or numpy.testing.assert_allclose) so the comparison
actually verifies the roundtrip predictions; locate and update the block that
currently creates original_hf/reimported_hf and the final
torch.testing.assert_close to instead call _run_predict for hf_reimported_dir
(or both HF and eden if you want) and perform the assertion on the
"log_probs_seqs" entries.
---
Nitpick comments:
In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py`:
- Line 132: Add a brief inline comment next to the torch.load call (the line
setting raw = torch.load(str(path), map_location="cpu", weights_only=False))
explaining that weights_only=False is required to load Savanna checkpoints
containing non-tensor data but is unsafe for untrusted files because it uses
pickle and can execute arbitrary code; state that callers must ensure the path
is trusted (or sanitize/validate inputs) before loading.
- Around line 95-96: The bare except in the single-file download block (the
except Exception: that logs "Single-file download failed for {repo_id}, trying
multi-part shards...") is too broad; replace it by catching specific HF and
network-related exceptions (for example
huggingface_hub.exceptions.RepositoryNotFoundError,
huggingface_hub.utils.entry_not_found_error or RevisionNotFoundError, and
requests.exceptions.HTTPError/ConnectionError) and bind the exception (except
(RepositoryNotFoundError, RevisionNotFoundError, HTTPError, ConnectionError) as
e:) so you can log the actual error (include exc_info or str(e)) while allowing
other unexpected exceptions to propagate (re-raise or don’t catch them). Locate
the except block around the single-file download attempt and update the except
clause and logger call accordingly.
In `@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py`:
- Around line 714-805: The Eden tests are launching the installed module (-m
bionemo.evo2.run.infer) instead of the local infer.py; update
test_infer_eden_deterministic to use the same launch pattern as
test_infer_eden_runs: replace the "-m bionemo.evo2.run.infer" style invocation
with the script path from _infer_script_path() (use that function to build the
cmd entry) and ensure the env prepends PYTHONPATH the same way (reuse
PRETEST_ENV modification logic used in test_infer_eden_runs) so the local source
infer.py is executed; reference test_infer_eden_deterministic,
test_infer_eden_runs, _infer_script_path, and PRETEST_ENV to locate where to
apply the changes.
In
`@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_eden_llama_roundtrip.py`:
- Line 63: The list comprehension using torch.load(pf) to build preds should
pass an explicit weights_only argument to avoid the deprecation warning and
ensure correct behavior for tensor-only prediction files; update the
comprehension that references pred_files and preds to call torch.load(pf,
weights_only=True) (or weights_only=False if non-weight objects are expected) so
loading is explicit and future-proof.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 77786c27-73d0-4f42-b35a-c5b97f7a217b
📒 Files selected for processing (22)
- bionemo-recipes/recipes/evo2_megatron/README.md
- bionemo-recipes/recipes/evo2_megatron/examples/fine-tuning-tutorial.ipynb
- bionemo-recipes/recipes/evo2_megatron/examples/zeroshot_brca1.ipynb
- bionemo-recipes/recipes/evo2_megatron/pyproject.toml
- bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/eden_provider.py
- bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py
- bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py
- bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/predict.py
- bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py
- bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/README.md
- bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/mbridge_to_vortex.py
- bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/nemo2_to_mbridge.py
- bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py
- bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/_eden_roundtrip_helper.py
- bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py
- bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py
- bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_predict.py
- bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_train.py
- bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_checkpoint_roundtrip.py
- bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_eden_llama_roundtrip.py
- bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_evo2.py
- bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_model_providers.py
```
evo2_export_mbridge_to_vortex \
  --mbridge-ckpt-dir /path/to/mbridge/iter_0000001 \
  --output-path /path/to/output/model_vortex.pt \
  --model-size evo2_1b_base
```
Add a language tag to these new fenced command blocks.
markdownlint will keep warning on both fences until they use a language identifier; `bash` fits the surrounding shell examples.
💡 Minimal fix
````diff
-```
+```bash
 evo2_export_mbridge_to_vortex \
   --mbridge-ckpt-dir /path/to/mbridge/iter_0000001 \
   --output-path /path/to/output/model_vortex.pt \
   --model-size evo2_1b_base
@@
-```
+```bash
 # Step 1: Savanna -> MBridge
 evo2_convert_savanna_to_mbridge \
   --savanna-ckpt-path arcinstitute/savanna_evo2_1b_base \
   --mbridge-ckpt-dir /tmp/mbridge_1b \
   --model-size evo2_1b_base \
````
</details>
Also applies to: 181-194
<details>
<summary>🧰 Tools</summary>
<details>
<summary>🪛 markdownlint-cli2 (0.21.0)</summary>
[warning] 157-157: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
</details>
</details>
<details>
<summary>🤖 Prompt for AI Agents</summary>
Verify each finding against the current code and only fix it if needed.
In @bionemo-recipes/recipes/evo2_megatron/README.md around lines 157 - 162, Add
a language identifier (bash) to the fenced code blocks containing the shell
commands (e.g., the blocks that show evo2_export_mbridge_to_vortex and
evo2_convert_savanna_to_mbridge) so markdownlint stops warning; locate the
backtick fences around those command examples and change the opening fence from ``` to ```bash for each occurrence (including the second block around the evo2_convert_savanna_to_mbridge example).
</details>
bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/eden_provider.py
```python
@dataclass
class Hyena20bModelProvider(HyenaModelProvider):
    """Config matching the Evo2 20B 1M context model (arcinstitute/evo2_20b).

    Source: evo2/configs/evo2-20b-1m.yml from ARC's evo2 repo.
    Layer pattern derived from: hcs=[0,4,7,11,14,18,21], hcm=[1,5,8,12,15,19,22],
    hcl=[2,6,9,13,16,20,23], attn=[3,10,17].
    """

    hybrid_override_pattern: str = "SDH*SDHSDH*SDHSDH*SDHSDH"
    num_layers: int = 24
    seq_length: int = 1_048_576
    hidden_size: int = 8192
    num_groups_hyena: int = 8192
    num_groups_hyena_medium: int = 512
    num_groups_hyena_short: int = 512
    make_vocab_size_divisible_by: int = 8
    tokenizer_library: str = "byte-level"
    mapping_type: str = "base"
    ffn_hidden_size: int = 22528
    gated_linear_unit: bool = True
    num_attention_heads: int = 64
    use_cpu_initialization: bool = False
    hidden_dropout: float = 0.0
    attention_dropout: float = 0.0
    params_dtype: torch.dtype = torch.bfloat16
    normalization: str = "RMSNorm"
    add_qkv_bias: bool = False
    add_bias_linear: bool = False
    layernorm_epsilon: float = 1e-6
    recompute_granularity: str = "full"
    recompute_method: str = "uniform"
    recompute_num_layers: int = 4
    hyena_init_method: str = "small_init"
    hyena_output_layer_init_method: str = "wang_init"
    hyena_filter_no_wd: bool = True
    rotary_base: int = 1_000_000
    seq_len_interpolation_factor: float = 128
    hyena_medium_conv_len: int = 128
    short_conv_len: int = 7
    hyena_short_conv_len: int = 3
    add_attn_proj_bias: bool = True
    hyena_out_proj_bias: bool = True
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# Check if these attributes are used in the HyenaModel or related code
rg -n "hyena_medium_conv_len|short_conv_len|hyena_short_conv_len|hyena_out_proj_bias" --type py
```

Repository: NVIDIA/bionemo-framework
Length of output: 4294

🏁 Script executed:

```shell
# Check the HyenaModelProvider parent class definition
rg -n "class HyenaModelProvider" -A 20 --type py
```

Repository: NVIDIA/bionemo-framework
Length of output: 2633

🏁 Script executed:

```shell
# Check how these attributes are used in the provide() method
sed -n '360,405p' bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py
```

Repository: NVIDIA/bionemo-framework
Length of output: 2802

🏁 Script executed:

```shell
# Check HyenaConfig definition to see parent attributes
cat -n bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_config.py | head -150
```

Repository: NVIDIA/bionemo-framework
Length of output: 5210

🏁 Script executed:

```shell
# Check _HyenaConfigForFlops and how it's initialized in provide()
sed -n '340,370p' bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py
```

Repository: NVIDIA/bionemo-framework
Length of output: 1868

🏁 Script executed:

```shell
# Search for hyena_out_proj_bias definition and usage
rg -n "hyena_out_proj_bias" -B 2 -A 2 --type py
```

Repository: NVIDIA/bionemo-framework
Length of output: 591

🏁 Script executed:

```shell
# Check if short_conv_len is referenced or if it should be something else
rg -n "short_conv_len" --type py
```

Repository: NVIDIA/bionemo-framework
Length of output: 2270
Critical naming error and unused attribute in Hyena20bModelProvider.
Line 695 defines short_conv_len: int = 7, which appears nowhere else in the codebase and conflicts with the consistent naming pattern used throughout (e.g., hyena_short_conv_len at line 696 and in HyenaConfig). This attribute is likely a naming error and should either be removed or renamed to match the standard attribute name.
Additionally, line 698 defines hyena_out_proj_bias: bool = True, which has no corresponding definition in HyenaConfig and is never referenced in any model code. This orphaned attribute should be removed unless it serves a documented purpose.
The attributes hyena_medium_conv_len and hyena_short_conv_len are valid overrides of HyenaConfig defaults and are used in the model's FLOPs calculation, but the above two attributes need clarification or removal.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py`
around lines 656 - 699, Hyena20bModelProvider defines an incorrect/unused
attribute short_conv_len and an orphan hyena_out_proj_bias; remove
short_conv_len (it duplicates/typoed counterpart hyena_short_conv_len) and
delete hyena_out_proj_bias unless you add a corresponding field in HyenaConfig
and wire it into the model code; update Hyena20bModelProvider by removing the
short_conv_len and hyena_out_proj_bias declarations (or rename short_conv_len to
hyena_short_conv_len only if HyenaConfig lacks that field and you also add it
there).
```python
MODEL_OPTIONS: dict[str, object] = {**HYENA_MODEL_OPTIONS, **EDEN_MODEL_OPTIONS}


def infer_model_type(model_size: str) -> str:
    """Infer the model architecture type from the model size key.

    Returns:
        "hyena" if the key is in HYENA_MODEL_OPTIONS, "eden" if in EDEN_MODEL_OPTIONS.

    Raises:
        ValueError: If the key is not found in any model options dict.
    """
    if model_size in HYENA_MODEL_OPTIONS:
        return "hyena"
    elif model_size in EDEN_MODEL_OPTIONS:
        return "eden"
    else:
        raise ValueError(f"Unknown model size: {model_size!r}. Valid options: {sorted(MODEL_OPTIONS.keys())}")
```
Potential inconsistency if key collision occurs between HYENA and EDEN options.
MODEL_OPTIONS is created via {**HYENA_MODEL_OPTIONS, **EDEN_MODEL_OPTIONS}, so if a key exists in both, Eden's value wins. However, infer_model_type checks HYENA_MODEL_OPTIONS first, so it would return "hyena" for a colliding key while MODEL_OPTIONS[key] returns the Eden provider.
Consider adding a runtime assertion to detect collisions:
Suggested collision check:

```diff
 MODEL_OPTIONS: dict[str, object] = {**HYENA_MODEL_OPTIONS, **EDEN_MODEL_OPTIONS}
+
+# Ensure no key collisions between Hyena and Eden options
+_colliding_keys = set(HYENA_MODEL_OPTIONS.keys()) & set(EDEN_MODEL_OPTIONS.keys())
+if _colliding_keys:
+    raise ValueError(f"Key collision between HYENA and EDEN model options: {_colliding_keys}")
```

Also, the docstring for `infer_model_type` is missing the Args section per Google-style convention.
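A self-contained toy (with placeholder option dicts and provider names, not the real providers) showing why the merge precedence and the membership-check order must agree:

```python
# Toy option tables with a deliberately colliding key (hypothetical names/values).
HYENA_MODEL_OPTIONS = {"evo2_1b_base": "HyenaProvider", "shared": "HyenaProvider"}
EDEN_MODEL_OPTIONS = {"eden_llama31": "EdenProvider", "shared": "EdenProvider"}

# Dict merge: the right-hand (Eden) value wins on collision.
MODEL_OPTIONS = {**HYENA_MODEL_OPTIONS, **EDEN_MODEL_OPTIONS}


def infer_model_type(model_size: str) -> str:
    # Membership check: Hyena wins on collision -- the opposite precedence.
    if model_size in HYENA_MODEL_OPTIONS:
        return "hyena"
    if model_size in EDEN_MODEL_OPTIONS:
        return "eden"
    raise ValueError(f"Unknown model size: {model_size!r}")


# The colliding key is classified as "hyena" but resolves to the Eden provider:
assert infer_model_type("shared") == "hyena"
assert MODEL_OPTIONS["shared"] == "EdenProvider"

# The suggested import-time guard detects the mismatch up front:
collisions = set(HYENA_MODEL_OPTIONS) & set(EDEN_MODEL_OPTIONS)
assert collisions == {"shared"}
```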
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py`
around lines 1115 - 1132, The bug: MODEL_OPTIONS is built as
{**HYENA_MODEL_OPTIONS, **EDEN_MODEL_OPTIONS} which lets EDEN override HYENA on
key collision, but infer_model_type checks HYENA first causing inconsistent
behavior; fix by adding a runtime collision check after constructing
MODEL_OPTIONS that computes collisions = set(HYENA_MODEL_OPTIONS) &
set(EDEN_MODEL_OPTIONS) and either raise a clear ValueError (or log and resolve
to a chosen precedence) if collisions is non-empty, and update infer_model_type
to rely on MODEL_OPTIONS (or document the chosen precedence) so behavior is
consistent; also update infer_model_type's docstring to include an Args section
describing the model_size parameter.
```python
if not is_hyena:
    model_provider._pg_collection = ProcessGroupCollection.use_mpu_process_groups()
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# Check how _pg_collection is used in the model provider hierarchy
rg -n "_pg_collection" --type py -C 3
```

Repository: NVIDIA/bionemo-framework
Length of output: 3544

🏁 Script executed:

```shell
# Find the model_provider class definition and type
rg -n "model_provider\s*=" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py -B 3 -A 1 | head -50
```

Repository: NVIDIA/bionemo-framework
Length of output: 364

🏁 Script executed:

```shell
# Check the model_provider class definition
rg -n "class.*Provider" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/ -A 10 | grep -A 10 "_pg_collection"
```

Repository: NVIDIA/bionemo-framework
Length of output: 50

🏁 Script executed:

```shell
# Look for any setter method or property for _pg_collection
rg -n "def.*pg_collection|_pg_collection\s*=" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/ -B 2 -A 2
```

Repository: NVIDIA/bionemo-framework
Length of output: 50

🏁 Script executed:

```shell
# Find the model_provider class - check evo2_provider.py for class definition
cat bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py | head -100
```

Repository: NVIDIA/bionemo-framework
Length of output: 4585

🏁 Script executed:

```shell
# Check for _pg_collection initialization in __init__ or class attributes
rg -n "_pg_collection" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py -B 5 -A 5
```

Repository: NVIDIA/bionemo-framework
Length of output: 1161

🏁 Script executed:

```shell
# Look for how ProcessGroupCollection is used - find imports and usage
rg -n "ProcessGroupCollection" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py -B 5 -A 5
```

Repository: NVIDIA/bionemo-framework
Length of output: 1140

🏁 Script executed:

```shell
# Check the context around finalize() call - what does it do?
rg -n "def finalize" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/ -A 15
```

Repository: NVIDIA/bionemo-framework
Length of output: 50

🏁 Script executed:

```shell
# Search for any public setter methods or properties related to pg_collection in the provider
rg -n "def.*pg_collection|@property" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py -B 2 -A 5
```

Repository: NVIDIA/bionemo-framework
Length of output: 492

🏁 Script executed:

```shell
# Check if _pg_collection is accessed anywhere else or how it's used after being set
rg -n "\._pg_collection" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py -B 3 -A 3
```

Repository: NVIDIA/bionemo-framework
Length of output: 292

🏁 Script executed:

```shell
# Search for any comments or TODOs around this pattern
rg -n "pg_collection|process.*group" bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py -B 5 -A 5 | head -80
```

Repository: NVIDIA/bionemo-framework
Length of output: 828

🏁 Script executed:

```shell
# Check if there are any comments in the test helper file that explain why this is needed
rg -n "pg_collection" bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/_eden_roundtrip_helper.py -B 10 -A 5
```

Repository: NVIDIA/bionemo-framework
Length of output: 663

🏁 Script executed:

```shell
# Look for use_mpu_process_groups method definition or documentation
rg -n "use_mpu_process_groups|ProcessGroupCollection" bionemo-recipes/recipes/evo2_megatron/src/ -B 2 -A 10 | head -100
```

Repository: NVIDIA/bionemo-framework
Length of output: 13059
Consider using a public API or setter method for _pg_collection configuration.
The direct assignment to model_provider._pg_collection works but sets a private attribute, which is fragile if the base class implementation changes. Unlike Hyena models that configure process groups internally through component constructors, non-Hyena models require this external setup. Consider whether a public setter method or initialization parameter would make this pattern more maintainable.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py` around
lines 334 - 336, You're directly mutating the private attribute
model_provider._pg_collection with
ProcessGroupCollection.use_mpu_process_groups(), which is fragile; instead
expose and use a public API or setter on the provider (e.g., add or call a
method like set_process_group_collection or a constructor/init parameter on the
ModelProvider class) so non-Hyena models can be configured without touching
internals—update the provider implementation to accept and store the
ProcessGroupCollection via that public method and replace the direct assignment
at the call site with the new setter or init call, keeping Hyena models'
internal behavior unchanged.
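The public-setter pattern the review suggests can be sketched with a minimal stand-in class. The names here (`ModelProvider`, `set_process_group_collection`) follow the review's own suggestion and are illustrative, not the actual bionemo API:

```python
from typing import Any, Optional


class ModelProvider:
    """Minimal stand-in for a Megatron-style model provider; the setter name
    set_process_group_collection is the review's proposal, not an existing API."""

    def __init__(self) -> None:
        self._pg_collection: Optional[Any] = None

    def set_process_group_collection(self, pg_collection: Any) -> None:
        """Public entry point so call sites never touch the private attribute."""
        self._pg_collection = pg_collection

    def provide(self) -> dict:
        # Fail fast with a clear message instead of a downstream AttributeError.
        if self._pg_collection is None:
            raise RuntimeError("call set_process_group_collection() before provide()")
        return {"pg_collection": self._pg_collection}


provider = ModelProvider()
# Stand-in string for ProcessGroupCollection.use_mpu_process_groups()
provider.set_process_group_collection("mpu-process-groups")
print(provider.provide()["pg_collection"])  # mpu-process-groups
```

This keeps the Hyena path untouched while giving non-Hyena call sites a supported hook that survives base-class refactors.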
```python
latest_file = mbridge_ckpt_dir / "latest_checkpointed_iteration.txt"
if latest_file.exists():
    iteration = latest_file.read_text().strip()
    iter_dir = mbridge_ckpt_dir / f"iter_{int(iteration):07d}"
else:
    iter_dirs = sorted(mbridge_ckpt_dir.glob("iter_*"))
    if not iter_dirs:
        raise FileNotFoundError(f"No iter_* directories in {mbridge_ckpt_dir}")
    iter_dir = iter_dirs[-1]
```
Handle direct iter_* checkpoint paths here.
This resolver only works when --mbridge-ckpt-dir points at the checkpoint root. The new docs in this PR also show passing .../iter_0000001 directly, and that currently falls through to glob("iter_*") on the iteration directory itself and raises FileNotFoundError even though the checkpoint is valid.
🛠️ Possible fix

```diff
 def load_mbridge_state_dict(mbridge_ckpt_dir: Path) -> dict[str, torch.Tensor]:
@@
-    latest_file = mbridge_ckpt_dir / "latest_checkpointed_iteration.txt"
-    if latest_file.exists():
-        iteration = latest_file.read_text().strip()
-        iter_dir = mbridge_ckpt_dir / f"iter_{int(iteration):07d}"
-    else:
-        iter_dirs = sorted(mbridge_ckpt_dir.glob("iter_*"))
-        if not iter_dirs:
-            raise FileNotFoundError(f"No iter_* directories in {mbridge_ckpt_dir}")
-        iter_dir = iter_dirs[-1]
+    if mbridge_ckpt_dir.name.startswith("iter_"):
+        iter_dir = mbridge_ckpt_dir
+    else:
+        latest_file = mbridge_ckpt_dir / "latest_checkpointed_iteration.txt"
+        if latest_file.exists():
+            iteration = latest_file.read_text().strip()
+            iter_dir = mbridge_ckpt_dir / f"iter_{int(iteration):07d}"
+        else:
+            iter_dirs = sorted(mbridge_ckpt_dir.glob("iter_*"))
+            if not iter_dirs:
+                raise FileNotFoundError(f"No iter_* directories in {mbridge_ckpt_dir}")
+            iter_dir = iter_dirs[-1]
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if mbridge_ckpt_dir.name.startswith("iter_"):
    iter_dir = mbridge_ckpt_dir
else:
    latest_file = mbridge_ckpt_dir / "latest_checkpointed_iteration.txt"
    if latest_file.exists():
        iteration = latest_file.read_text().strip()
        iter_dir = mbridge_ckpt_dir / f"iter_{int(iteration):07d}"
    else:
        iter_dirs = sorted(mbridge_ckpt_dir.glob("iter_*"))
        if not iter_dirs:
            raise FileNotFoundError(f"No iter_* directories in {mbridge_ckpt_dir}")
        iter_dir = iter_dirs[-1]
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/mbridge_to_vortex.py`
around lines 48 - 56, The current logic assumes mbridge_ckpt_dir is the
checkpoint root and fails if the user passes an iter_* directory; modify the
resolver around latest_file/iter_dir so that if mbridge_ckpt_dir.name matches
the iter_* pattern (e.g., startswith "iter_" or matches r"^iter_\d+$") you treat
mbridge_ckpt_dir itself as the iter_dir; otherwise keep the existing flow (check
for latest_checkpointed_iteration.txt, parse iteration into iter_{:07d}, or
fallback to glob("iter_*")). Update uses of iteration/iter_dirs to reflect this
early-path selection so valid direct iter_* paths are accepted.
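The suggested resolver logic can be exercised in isolation; this sketch uses the stricter `iter_\d+` regex form mentioned in the prompt rather than a bare `startswith`, and is illustrative rather than the shipped code:

```python
import re
import tempfile
from pathlib import Path


def resolve_iter_dir(mbridge_ckpt_dir: Path) -> Path:
    """Resolve the iteration directory, accepting either a checkpoint root
    or a direct iter_* path (sketch of the fix proposed in the review)."""
    if re.fullmatch(r"iter_\d+", mbridge_ckpt_dir.name):
        return mbridge_ckpt_dir  # user pointed directly at an iteration directory
    latest_file = mbridge_ckpt_dir / "latest_checkpointed_iteration.txt"
    if latest_file.exists():
        iteration = latest_file.read_text().strip()
        return mbridge_ckpt_dir / f"iter_{int(iteration):07d}"
    iter_dirs = sorted(mbridge_ckpt_dir.glob("iter_*"))
    if not iter_dirs:
        raise FileNotFoundError(f"No iter_* directories in {mbridge_ckpt_dir}")
    return iter_dirs[-1]


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "iter_0000001").mkdir()
    (root / "latest_checkpointed_iteration.txt").write_text("1")
    assert resolve_iter_dir(root).name == "iter_0000001"                      # root path
    assert resolve_iter_dir(root / "iter_0000001").name == "iter_0000001"     # direct path
```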
```python
embed_key = "embedding.word_embeddings.weight"
if embed_key in mbridge_state_dict:
    vortex_sd["embedding_layer.weight"] = mbridge_state_dict[embed_key]
    vortex_sd["unembed.weight"] = mbridge_state_dict[embed_key]

for layer_idx, symbol in enumerate(pattern):
    prefix = f"decoder.layers.{layer_idx}"
    block_prefix = f"blocks.{layer_idx}"

    if symbol != "*":
        _convert_hyena_layer(
            mbridge_state_dict,
            vortex_sd,
            prefix,
            block_prefix,
            symbol,
            te_enabled,
            num_groups,
            filter_order,
            medium_conv_len,
        )
    else:
        _convert_attention_layer(
            mbridge_state_dict,
            vortex_sd,
            prefix,
            block_prefix,
            te_enabled,
        )

    _convert_mlp(mbridge_state_dict, vortex_sd, prefix, block_prefix, te_enabled)

final_norm_key = "decoder.final_norm.weight"
if final_norm_key in mbridge_state_dict:
    vortex_sd["norm.scale"] = mbridge_state_dict[final_norm_key]
```
Validate required tensor keys before writing the export.
Every mapping here is optional, so a bad --model-size or --no-te choice can silently drop required tensors and still produce a .pt file plus config.json. For a format converter, that makes corruption very hard to catch. Please collect missing mandatory keys per layer and raise before saving.
Also applies to: 172-296
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/mbridge_to_vortex.py`
around lines 133 - 167, The export currently silently omits required tensors
(e.g., embedding.word_embeddings.weight, decoder.final_norm.weight and per-layer
weights produced by _convert_hyena_layer, _convert_attention_layer, and
_convert_mlp) so add a validation pass after the loop that defines the mandatory
target keys (embedding_layer.weight, unembed.weight, norm.scale and all expected
per-layer keys derived from prefix/block_prefix and the pattern) and check their
presence in the resulting vortex_sd (or mbridge_state_dict if conversions expect
source keys); collect missing keys into a list and raise a descriptive exception
listing layer and key names (including references to the layer index and symbol)
before writing the .pt/config.json to fail fast on bad --model-size or --no-te
choices.
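A fail-fast validation pass could look like the following sketch. The three global keys come directly from the review text; the per-layer key is a hypothetical placeholder, since a real check would enumerate every tensor each converter is expected to emit:

```python
def validate_vortex_export(vortex_sd: dict, pattern: str) -> None:
    """Collect missing mandatory keys and raise before anything is written."""
    required = ["embedding_layer.weight", "unembed.weight", "norm.scale"]
    for layer_idx, symbol in enumerate(pattern):
        # Hypothetical per-layer key; the real converter would list all tensors
        # that _convert_hyena_layer/_convert_attention_layer/_convert_mlp emit.
        required.append(f"blocks.{layer_idx}.mlp.l1.weight")
    missing = [key for key in required if key not in vortex_sd]
    if missing:
        raise KeyError(f"Vortex export is missing {len(missing)} tensor(s): {missing}")


# A state dict missing 'norm.scale' fails before the .pt/config.json are saved:
missing_sd = {"embedding_layer.weight": 0, "unembed.weight": 0}
try:
    validate_vortex_export(missing_sd, pattern="")
except KeyError as err:
    print(f"caught: {err}")
```

Running this after the conversion loop but before serialization turns a silently corrupt export into an immediate, descriptive error.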
```python
@pytest.fixture(scope="module")
def savanna_checkpoint_path(tmp_path_factory):
    """Download the 1b savanna checkpoint from HuggingFace."""
    cache_dir = tmp_path_factory.mktemp("savanna_ckpt")
    path = hf_hub_download(
        repo_id=SAVANNA_1B_REPO,
        filename="savanna_evo2_1b_base.pt",
        local_dir=str(cache_dir),
    )
    return path


@pytest.fixture(scope="module")
def vortex_reference_path(tmp_path_factory):
    """Download the 1b vortex checkpoint from HuggingFace."""
    cache_dir = tmp_path_factory.mktemp("vortex_ref")
    path = hf_hub_download(
        repo_id=VORTEX_1B_REPO,
        filename="evo2_1b_base.pt",
        local_dir=str(cache_dir),
    )
    return path


@pytest.fixture(scope="module")
def roundtrip_vortex_sd(savanna_checkpoint_path):
    """Perform savanna -> mbridge -> vortex conversion and return the vortex state dict."""
    provider_cls = HYENA_MODEL_OPTIONS[MODEL_SIZE]
    model_provider = provider_cls()
    pattern = model_provider.hybrid_override_pattern

    savanna_sd = load_savanna_state_dict(savanna_checkpoint_path)
    mbridge_sd = savanna_to_mbridge_state_dict(savanna_sd, pattern, te_enabled=True)
    vortex_sd = mbridge_to_vortex_state_dict(mbridge_sd, model_provider, te_enabled=True)
    return vortex_sd


@pytest.fixture(scope="module")
def vortex_reference_sd(vortex_reference_path):
    """Load the reference vortex state dict from HuggingFace."""
    return torch.load(vortex_reference_path, map_location="cpu", weights_only=False)
```
Pin HuggingFace checkpoints to immutable revisions and load reference checkpoint safely.
The fixtures download Savanna and Vortex checkpoints without specifying a revision=, so the golden data drifts if upstream repos change. Additionally, vortex_reference_sd() uses weights_only=False when loading a remote .pt file, unnecessarily allowing pickle code execution in the test. Add explicit commit SHAs to both hf_hub_download() calls and change the torch.load to use weights_only=True for the reference fixture (or use safetensors format if available).
Example fix

```diff
 path = hf_hub_download(
     repo_id=SAVANNA_1B_REPO,
     filename="savanna_evo2_1b_base.pt",
+    revision="<commit-sha>",
     local_dir=str(cache_dir),
 )
 path = hf_hub_download(
     repo_id=VORTEX_1B_REPO,
     filename="evo2_1b_base.pt",
+    revision="<commit-sha>",
     local_dir=str(cache_dir),
 )
-    return torch.load(vortex_reference_path, map_location="cpu", weights_only=False)
+    return torch.load(vortex_reference_path, map_location="cpu", weights_only=True)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_checkpoint_roundtrip.py`
around lines 39 - 79, The fixtures download mutable HuggingFace checkpoints and
allow unsafe pickle loading; update savanna_checkpoint_path and
vortex_reference_path to pass explicit immutable commit SHAs via the revision=
parameter in their hf_hub_download(...) calls (use the specific commit SHA for
SAVANNA_1B_REPO and VORTEX_1B_REPO to pin the golden data), and modify
vortex_reference_sd to load the reference safely by calling torch.load(...,
map_location="cpu", weights_only=True) (or prefer safetensors if a .safetensors
artifact exists) so remote .pt files cannot execute pickle code.
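The `weights_only=False` concern is easy to demonstrate with the standard library alone: when `weights_only=False`, `torch.load` falls back to pickle, and pickle will invoke whatever callable a file's `__reduce__` payload names during deserialization. A minimal, harmless sketch of that mechanism:

```python
import io
import pickle


class Payload:
    """A crafted object whose __reduce__ runs an arbitrary callable at load time.
    This is exactly what an untrusted .pt file can do under weights_only=False."""

    def __reduce__(self):
        # Harmless stand-in; a malicious file could return (os.system, ("...",))
        return (print, ("arbitrary code ran during unpickling",))


buf = io.BytesIO()
pickle.dump(Payload(), buf)
buf.seek(0)
pickle.load(buf)  # the print above fires here, during deserialization
```

With `weights_only=True`, PyTorch restricts unpickling to tensor and container types, so a payload like this is rejected instead of executed.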
```python
@pytest.mark.slow
def test_roundtrip_prediction_equality(
    eden_ckpt: Path,
    hf_exported_dir: Path,
    hf_reimported_dir: Path,
    tmp_path,
):
    """Verify that predictions from the original and roundtripped models match.

    Runs predict on both the original mbridge checkpoint and on the re-imported HF checkpoint
    (loaded via AutoBridge) and compares per-token log probabilities.
    """
    num_sequences = 2
    seq_lengths = [64, 64]

    fasta_path = tmp_path / "test.fasta"
    create_fasta_file(fasta_path, num_sequences, sequence_lengths=seq_lengths, repeating_dna_pattern=ALU_SEQUENCE)

    env = copy.deepcopy(PRETEST_ENV)
    if is_a6000_gpu():
        env["NCCL_P2P_DISABLE"] = "1"

    # Predictions from the original mbridge checkpoint
    original_preds = _run_predict(eden_ckpt, fasta_path, tmp_path / "orig_preds", env)

    assert "log_probs_seqs" in original_preds
    assert "seq_idx" in original_preds

    # Load the original and reimported HF models and compare forward pass
    original_hf = LlamaForCausalLM.from_pretrained(hf_exported_dir, torch_dtype=torch.bfloat16).eval()
    reimported_hf = LlamaForCausalLM.from_pretrained(hf_reimported_dir, torch_dtype=torch.bfloat16).eval()

    # Quick sanity: HF forward pass should produce identical outputs for both
    input_ids = torch.randint(0, 256, (1, 32))
    with torch.no_grad():
        orig_logits = original_hf(input_ids).logits
        reimp_logits = reimported_hf(input_ids).logits

    torch.testing.assert_close(orig_logits, reimp_logits, atol=0, rtol=0)
```
Test computes original_preds but only compares HF models to each other.
The test runs _run_predict on eden_ckpt (lines 170-174) but only asserts that keys exist. The actual comparison (lines 183-186) is between original_hf and reimported_hf forward passes, not involving original_preds.
If the intent is to verify that the roundtripped model produces the same predictions as the original mbridge checkpoint, the comparison should include original_preds. Otherwise, the docstring claim "compares per-token log probabilities" is misleading since only HF-to-HF comparison occurs.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_eden_llama_roundtrip.py`
around lines 148 - 186, The test currently computes original_preds via
_run_predict(eden_ckpt, ...) but then only compares original_hf and
reimported_hf logits; change the test to run _run_predict on the roundtripped HF
checkpoint (use hf_reimported_dir) to produce hf_preds and then compare
original_preds["log_probs_seqs"] to hf_preds["log_probs_seqs"] (or the
equivalent per-token log-prob key) using a numeric assert (e.g.,
torch.testing.assert_close or numpy.testing.assert_allclose) so the comparison
actually verifies the roundtrip predictions; locate and update the block that
currently creates original_hf/reimported_hf and the final
torch.testing.assert_close to instead call _run_predict for hf_reimported_dir
(or both HF and eden if you want) and perform the assertion on the
"log_probs_seqs" entries.
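The missing comparison could be a small helper like the following sketch, which operates on the `"log_probs_seqs"` key named in the test; the tolerances are illustrative and would need tuning for bf16 inference:

```python
import math


def assert_logprobs_close(original: dict, roundtrip: dict, rel_tol=1e-3, abs_tol=1e-5):
    """Compare per-token log probabilities between two prediction dicts."""
    orig_seqs = original["log_probs_seqs"]
    new_seqs = roundtrip["log_probs_seqs"]
    assert len(orig_seqs) == len(new_seqs), "sequence count mismatch"
    for i, (a_seq, b_seq) in enumerate(zip(orig_seqs, new_seqs)):
        assert len(a_seq) == len(b_seq), f"token count mismatch in seq {i}"
        for j, (a, b) in enumerate(zip(a_seq, b_seq)):
            assert math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol), (
                f"log-prob mismatch at seq {i}, token {j}: {a} vs {b}"
            )


# Toy usage with hand-written values standing in for _run_predict outputs:
assert_logprobs_close(
    {"log_probs_seqs": [[-0.51, -1.20]]},
    {"log_probs_seqs": [[-0.51, -1.20]]},
)
```

In the test itself, `roundtrip` would come from running `_run_predict` against the re-imported checkpoint, making the docstring's claim about comparing per-token log probabilities actually hold.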
Signed-off-by: John St. John <jstjohn@nvidia.com>
…hn/evo2_llama_configs_and_savanna_convert
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Signed-off-by: John St. John <jstjohn@nvidia.com>
Description
This PR adds Eden (Llama 3.1) model support, Savanna/Vortex checkpoint converters, and a standardized model naming convention to the Megatron Bridge–based Evo2 recipe (`bionemo-recipes/recipes/evo2_megatron/`).

Eden (Llama 3.1) model support

- `eden_provider.py` defines `EdenModelProvider` and size-specific subclasses (`eden_7b` through `eden_35b`) that inherit from `Llama31ModelProvider`.
- `train.py` now dispatches to `gpt_forward_step` for Eden models and automatically disables `fp32_residual_connection` (incompatible with standard TE `LayerNormLinear` layers — Hyena handles this via manual dtype casting, but GPT/Llama does not).
- `infer.py` now initializes `ProcessGroupCollection` for non-Hyena providers (required by `GPTModelProvider.provide()`) and uses `StaticInferenceContext` instead of `HyenaInferenceContext` for Eden models. The `flash_decode` attribute is guarded to Hyena-only.
- `predict.py` already worked architecture-agnostically via dynamic model loading; no changes required.

Checkpoint converters
- `savanna_to_mbridge.py` — converts ARC Savanna `.pt` checkpoints (local or downloaded from Hugging Face via `hf_hub_download`) into MBridge distributed checkpoint format.
- `mbridge_to_vortex.py` — exports MBridge checkpoints to ARC's single-file Vortex inference format, handling MLP weight splitting, Hyena filter pole/residue computation, and TE layernorm key remapping.
- New CLI entry points (`evo2_convert_savanna_to_mbridge`, `evo2_export_mbridge_to_vortex`).

Model naming convention
The previous model size keys (`1b`, `7b`, `40b`, `7b_arc_longcontext`, …) were ambiguous — `7b` referred to Striped Hyena while `7B` referred to Llama. This PR replaces them with explicit, architecture-prefixed keys:

- `evo2_*` for models matching public ARC checkpoints (e.g. `evo2_1b_base`, `evo2_7b`, `evo2_40b_base`); `_base` = 8K context, without it = 1M context.
- `striped_hyena_*_nv` for NVIDIA-modified Hyena variants.
- `eden_*` for Llama 3.1 variants.
- `evo2_20b` config based on `arcinstitute/savanna_evo2_20b`.

Documentation updates
- `README.md` — added model naming convention tables and a Vortex export section with a round-trip example; updated all CLI examples to the new model keys.
- `checkpoint/README.md` — updated `--model-size` documentation.
- Notebooks (`zeroshot_brca1.ipynb`, `fine-tuning-tutorial.ipynb`) — updated `MODEL_SIZE` and `--model-size` references.
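The model-type branching described for `train.py`, keyed off the new architecture-prefixed names, can be sketched as a toy dispatch (the returned strings stand in for the real forward-step callables):

```python
def select_forward_step(model_size: str) -> str:
    """Illustrative dispatch: eden_* (Llama 3.1) keys take the GPT forward
    step, while all other keys stay on the Hyena path."""
    if model_size.startswith("eden_"):
        return "gpt_forward_step"
    return "hyena_forward_step"


assert select_forward_step("eden_7b") == "gpt_forward_step"
assert select_forward_step("evo2_1b_base") == "hyena_forward_step"
assert select_forward_step("striped_hyena_7b_nv") == "hyena_forward_step"
```

Prefix-based dispatch is exactly what the explicit naming convention buys: the architecture is recoverable from the key alone, with no `7b`-vs-`7B` ambiguity.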
Training an Eden model:
Converting Savanna checkpoint to MBridge:
Exporting MBridge to Vortex:
Type of changes
CI Pipeline Configuration
Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.
Unit tests marked as `@pytest.mark.multi_gpu` or `@pytest.mark.distributed` are not run in the PR pipeline. For more details, see CONTRIBUTING.
Note
By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage.
Authorizing CI Runs
We use copy-pr-bot to manage authorization of CI runs on NVIDIA's compute resources.

- Commits will automatically be copied to a `pull-request/` prefixed branch in the source repository (e.g. `pull-request/123`).
- An `/ok to test` comment on the pull request is required to trigger CI. This will need to be done for each new commit.

Triggering Code Rabbit AI Review
To trigger a code review from CodeRabbit, comment on a pull request with one of these commands:
See https://docs.coderabbit.ai/reference/review-commands for a full list of commands.
Pre-submit Checklist
Summary by CodeRabbit
New Features
Documentation — updated to the new model keys (e.g., `evo2_1b_base`, `evo2_7b`).